import requests
import json
import os
import asyncio
import aiohttp

# 接口地址(主)
_api_url="http://14.103.236.44:18001/api/ai/detect"
# 接口地址(备)
#_api_url_backup="http://116.204.115.87:19001/api/ai/detect"


_detect_static_enum_dict = {
    # 英文放到中文前面，中文放到英文后面
    "001fire": "001火灾检测",
    "002garbage": "002垃圾检测",
    "003faceDetect": "003人脸检测",
    "004callPhone": "004打电话识别",
    "005fighting": "005人员打架",
    "006safeHat": "006安全帽识别",
    "007smoke": "007吸烟识别",
    "008walkerMan": "008行人识别",
    "009maskDetect": "009口罩检测",
    "010fallDetect": "010人员跌倒识别",
    "011crossFenceDetect": "011翻越围栏识别",
    "012streetVendorDetect": "012街头流动商贩识别"
}


def upload_image_for_detection(image_path, model_code="003faceDetect"):
    """
    上传图片到AI检测接口并获取检测结果
    
    参数:
        image_path (str): 图片文件的路径
        api_url (str): AI检测接口的URL
        model_code (str): 模型编码，
    
    返回:
        dict: 检测结果的JSON对象
    """
    try:
        # 检查图片文件是否存在
        if not os.path.exists(image_path):
            print(f"错误: 图片文件不存在 - {image_path}")
            return None
        
        print(f"正在上传图片: {image_path}")
        print(f"目标API: {_api_url}")
        print(f"使用模型编码: {model_code}")
        
        # 准备文件数据
        with open(image_path, 'rb') as f:
            files = {
                'file': (os.path.basename(image_path), f, 'image/jpeg')
            }
            
            # 表单数据，使用model_code参数
            data = {
                'model_code': model_code  # 模型编码参数
            }
            
            # 发送POST请求
            response = requests.post(_api_url, files=files, data=data)
            
            # 检查响应状态
            response.raise_for_status()
            
            # 尝试解析JSON响应
            try:
                result = response.json()
                return result
            except json.JSONDecodeError:
                print(f"警告: 无法解析响应为JSON，原始响应内容: {response.text}")
                return {"raw_response": response.text}
                
    except requests.exceptions.RequestException as e:
        print(f"请求出错: {e}")
        return None
    except Exception as e:
        print(f"发生错误: {e}")
        return None

def print_detection_result(result):
    """
    打印检测结果
    
    参数:
        result (dict): 检测结果的JSON对象
    """
    if not result:
        print("没有有效的检测结果")
        return
    
    # 美化打印JSON结果
    print("\n检测结果:")
    print(json.dumps(result, ensure_ascii=False, indent=2))
    
    # 如果有检测结果，尝试提取关键信息
    if isinstance(result, dict):
        # 打印基本信息
        if 'original_url' in result:
            print(f"\n原始图片URL: {result['original_url']}")
        
        if 'detected_url' in result:
            print(f"检测结果URL: {result['detected_url']}")
        
        # 打印图片尺寸
        if 'image_size' in result:
            size = result['image_size']
            print(f"图片尺寸: {size.get('width', 0)}x{size.get('height', 0)}")
        elif 'width' in result and 'height' in result:
            print(f"图片尺寸: {result['width']}x{result['height']}")
        
        # 打印使用的模型信息
        if 'model_used' in result:
            model_info = result['model_used']
            if isinstance(model_info, dict):
                print(f"使用模型: {model_info.get('name', '未知')} ({model_info.get('code', 'unknown')})")
                if model_info.get('is_fallback', False):
                    print("提示: 由于指定模型不可用，系统自动使用了备选模型")
            else:
                print(f"使用模型: {model_info}")
        
        # 打印消息
        if 'message' in result:
            print(f"\n消息: {result['message']}")
        
        # 打印检测到的目标
        if 'detections' in result:
            detections = result['detections']
            print(f"\n检测到的目标数量: {len(detections)}")
            
            # 打印每个检测目标的详细信息
            for i, detection in enumerate(detections):
                print(f"\n目标 {i+1}:")
                if 'class' in detection:
                    print(f"  类别: {detection['class']}")
                if 'confidence' in detection:
                    print(f"  置信度: {detection['confidence']:.2f}")
                if 'bbox' in detection:
                    bbox = detection['bbox']
                    print(f"  边界框: xmin={bbox.get('xmin', 0)}, ymin={bbox.get('ymin', 0)}, "
                          f"xmax={bbox.get('xmax', 0)}, ymax={bbox.get('ymax', 0)}")

def detect_and_save_result(image_path, model_code="003faceDetect", save_json=True):
    """
    检测图片并可选保存结果到JSON文件
    
    参数:
        image_path (str): 图片路径
        model_code (str): 模型编码
        save_json (bool): 是否保存结果到JSON文件
    
    返回:
        dict: 检测结果
    """
    # 调用API进行检测
    detection_result = upload_image_for_detection(image_path, model_code=model_code)
    
    # 打印检测结果
    print_detection_result(detection_result)
    
    # 保存结果到JSON文件
    if save_json and detection_result:
        # 生成保存文件名
        base_name = os.path.splitext(os.path.basename(image_path))[0]
        if not os.path.exists("./detect_out"):
            os.makedirs("./detect_out", exist_ok=True)
        json_file = f".\detect_out\detection_result_{base_name}.json"
        
        # 保存JSON文件
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(detection_result, f, ensure_ascii=False, indent=2)
        
        # 根据模型类型显示不同的保存消息
        if model_code == "003faceDetect":
            print(f"\n人脸检测结果已保存到: {json_file}")
        elif model_code == "002garbage":
            print(f"\n垃圾检测结果已保存到: {json_file}")
        elif model_code == "001fire":
            print(f"\n火灾检测结果已保存到: {json_file}")
        else:
            print(f"\n检测结果已保存到: {json_file}")
    
    # 根据模型类型显示不同的完成消息
    if model_code == "003faceDetect":
        print("\n人脸检测完成")
    elif model_code == "002garbage":
        print("\n垃圾检测完成")
    elif model_code == "001fire":
        print("\n火灾检测完成")
    else:
        print("\n检测完成")
    
    return detection_result

async def detect_image(image_path, model_code="003faceDetect"):
    """
    异步检测图片并返回结果
    
    参数:
        image_path (str): 图片路径
        model_code (str): 模型编码
    
    返回:
        dict: 检测结果的JSON对象
    """
    try:
        # 检查图片文件是否存在
        if not os.path.exists(image_path):
            print(f"错误: 图片文件不存在 - {image_path}")
            return None
        
        print(f"正在上传图片: {image_path}")
        print(f"目标API: {_api_url}")
        print(f"使用模型编码: {model_code}")
        
        # 准备文件数据
        with open(image_path, 'rb') as f:
            image_data = f.read()
        
        # 构建multipart/form-data请求
        async with aiohttp.ClientSession() as session:
            data = aiohttp.FormData()
            data.add_field('file', 
                          image_data, 
                          filename=os.path.basename(image_path), 
                          content_type='image/jpeg')
            data.add_field('model_code', model_code)
            
            # 发送POST请求
            async with session.post(_api_url, data=data) as response:
                # 检查响应状态
                response.raise_for_status()
                
                # 尝试解析JSON响应
                try:
                    result = await response.json()
                    return result
                except json.JSONDecodeError:
                    print(f"警告: 无法解析响应为JSON，原始响应内容: {await response.text()}")
                    return {"raw_response": await response.text()}
    
    except aiohttp.ClientError as e:
        print(f"请求出错: {e}")
        return None
    except Exception as e:
        print(f"发生错误: {e}")
        return None

async def run_async_demo():
    """
    运行异步演示示例
    """
    print("[异步垃圾检测示例]")
    garbage_image_path = r".\images\002garbage\test_garbage1.jpg"
    
    # 使用异步方法检测垃圾
    result = await detect_image(garbage_image_path, model_code="002garbage")
    
    # 打印检测结果
    print_detection_result(result)
    
    # 保存结果到JSON文件
    if result:
        # 生成保存文件名
        base_name = os.path.splitext(os.path.basename(garbage_image_path))[0]
        if not os.path.exists("./detect_out"):
            os.makedirs("./detect_out", exist_ok=True)
        json_file = f".\detect_out\detection_result_{base_name}.json"
        
        # 保存JSON文件
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(result, f, ensure_ascii=False, indent=2)
        
        print(f"\n垃圾检测结果已保存到: {json_file}")
    
    print("\n异步垃圾检测完成")

def test_fire_detection(image_path):
    """测试火灾检测模型"""
    print(f"\n开始测试火灾检测: {image_path}")
    result = detect_and_save_result(image_path, model_code="001fire")
    return result

def test_garbage_detection(image_path):
    """测试垃圾检测模型"""
    print(f"\n开始测试垃圾检测: {image_path}")
    result = detect_and_save_result(image_path, model_code="002garbage")
    return result

def test_face_detection(image_path):
    """测试人脸检测模型"""
    print(f"\n开始测试人脸检测: {image_path}")
    result = detect_and_save_result(image_path, model_code="003faceDetect")
    return result

def test_callphone_detection(image_path):
    """测试打电话识别模型"""
    print(f"\n开始测试手机打电话检测: {image_path}")
    result = detect_and_save_result(image_path, model_code="004callPhone")
    return result

def test_fighting_detection(image_path):
    """测试人员打架检测模型"""
    print(f"\n开始测试人员打架检测: {image_path}")
    result = detect_and_save_result(image_path, model_code="005fighting")
    return result

def test_safehat_detection(image_path):
    """测试安全帽识别模型"""
    print(f"\n开始测试安全帽识别: {image_path}")
    result = detect_and_save_result(image_path, model_code="006safeHat")
    return result

def test_smoke_detection(image_path):
    """测试吸烟识别模型"""
    print(f"\n开始测试吸烟识别: {image_path}")
    result = detect_and_save_result(image_path, model_code="007smoke")
    return result

def test_walker_detection(image_path):
    """测试行人识别模型"""
    print(f"\n开始测试行人识别: {image_path}")
    result = detect_and_save_result(image_path, model_code="008walkerMan")
    return result

def test_mask_detection(image_path):
    """测试口罩检测模型"""
    print(f"\n开始测试口罩检测: {image_path}")
    result = detect_and_save_result(image_path, model_code="009maskDetect")
    return result

def test_fall_detection(image_path):
    """测试人员跌倒识别模型"""
    print(f"\n开始测试人员跌倒识别: {image_path}")
    result = detect_and_save_result(image_path, model_code="010fallDetect")
    return result

def test_crossfence_detection(image_path):
    """测试翻越围栏识别模型"""
    print(f"\n开始测试翻越围栏识别: {image_path}")
    result = detect_and_save_result(image_path, model_code="011crossFenceDetect")
    return result

def test_streetvendor_detection(image_path):
    """测试街头流动商贩识别模型"""
    print(f"\n开始测试街头流动商贩识别: {image_path}")
    result = detect_and_save_result(image_path, model_code="012streetVendorDetect")
    return result

if __name__ == "__main__":
    import argparse
    
    # 创建命令行参数解析器
    parser = argparse.ArgumentParser(description='VisionForge SDK 模型测试工具')
    
    # 为每种模型类型添加图片路径参数
    parser.add_argument('--fire-image', type=str, help='火灾检测测试图片路径')
    parser.add_argument('--garbage-image', type=str, help='垃圾检测测试图片路径')
    parser.add_argument('--face-image', type=str, help='人脸检测测试图片路径')
    parser.add_argument('--callphone-image', type=str, help='打电话识别测试图片路径')
    parser.add_argument('--fighting-image', type=str, help='人员打架检测测试图片路径')
    parser.add_argument('--safehat-image', type=str, help='安全帽识别测试图片路径')
    parser.add_argument('--smoke-image', type=str, help='吸烟识别测试图片路径')
    parser.add_argument('--walker-image', type=str, help='行人识别测试图片路径')
    parser.add_argument('--mask-image', type=str, help='口罩检测测试图片路径')
    parser.add_argument('--fall-image', type=str, help='人员跌倒识别测试图片路径')
    parser.add_argument('--crossfence-image', type=str, help='翻越围栏识别测试图片路径')
    parser.add_argument('--streetvendor-image', type=str, help='街头流动商贩识别测试图片路径')
    
    # 添加运行所有模型测试的参数
    parser.add_argument('--all', action='store_true', help='运行所有模型测试（需要提供所有图片路径）')
    
    # 解析参数
    args = parser.parse_args()
    
    # 测试结果统计
    tests_run = 0
    tests_succeeded = 0
    
    # 模型测试映射
    model_tests = [
        (args.fire_image, test_fire_detection, "001fire - 火灾检测"),
        (args.garbage_image, test_garbage_detection, "002garbage - 垃圾检测"),
        (args.face_image, test_face_detection, "003faceDetect - 人脸检测"),
        (args.callphone_image, test_callphone_detection, "004callPhone - 打电话识别"),
        (args.fighting_image, test_fighting_detection, "005fighting - 人员打架"),
        (args.safehat_image, test_safehat_detection, "006safeHat - 安全帽识别"),
        (args.smoke_image, test_smoke_detection, "007smoke - 吸烟识别"),
        (args.walker_image, test_walker_detection, "008walkerMan - 行人识别"),
        (args.mask_image, test_mask_detection, "009maskDetect - 口罩检测"),
        (args.fall_image, test_fall_detection, "010fallDetect - 人员跌倒识别"),
        (args.crossfence_image, test_crossfence_detection, "011crossFenceDetect - 翻越围栏识别"),
        (args.streetvendor_image, test_streetvendor_detection, "012streetVendorDetect - 街头流动商贩识别")
    ]
    
    print("===== VisionForge SDK 多模型检测演示 =====")
    print(f"可用模型数量: {len(_detect_static_enum_dict)}")
    print("模型列表:")
    for code, name in _detect_static_enum_dict.items():
        print(f"  - {code}: {name}")
    print("\n")
    
    # 运行指定的模型测试
    any_test_run = False
    for image_path, test_func, model_name in model_tests:
        if args.all or image_path:
            any_test_run = True
            tests_run += 1
            try:
                result = test_func(image_path)
                if result:
                    tests_succeeded += 1
                    print(f"✓ {model_name} 测试成功")
                else:
                    print(f"✗ {model_name} 测试失败")
            except Exception as e:
                print(f"✗ {model_name} 测试异常: {str(e)}")
    
    # 如果没有指定任何测试，运行默认测试
    if not any_test_run:
        print("未指定测试模型，运行默认测试...")
        # 默认测试手机打电话检测
        default_callphone_path = r"E:\PyProject_yywl\01ultralytics-main-garbage\SDKDemo\CSharp\AiProjectCSharp\AiProjectCSharp\bin\Debug\TestImg\004打电话识别004callPhone\打电话1.jpg"
        test_callphone_detection(default_callphone_path)
        
        # 尝试测试人员打架检测
        try:
            default_fighting_path = r"E:\PyProject_yywl\01ultralytics-main-garbage\SDKDemo\CSharp\AiProjectCSharp\AiProjectCSharp\bin\Debug\TestImg\005人员打架005fight\打架1.jpg"
            if os.path.exists(default_fighting_path):
                test_fighting_detection(default_fighting_path)
        except Exception as e:
            print(f"默认人员打架检测失败: {str(e)}")
    
    # 打印测试统计信息
    if tests_run > 0:
        print(f"\n===== 测试统计 =====")
        print(f"总测试数: {tests_run}")
        print(f"成功测试数: {tests_succeeded}")
        print(f"失败测试数: {tests_run - tests_succeeded}")
    
    print(f"\n所有检测演示完成!")
    print("\n使用帮助:")
    print("  python 'VisionForge SDK_python.py' --model-image <图片路径>")
    print("  例如: python 'VisionForge SDK_python.py' --fire-image fire.jpg --face-image face.jpg")